import pandas as pd
import statsmodels.formula.api as smf
import numpy as np

# Read the cleaned household panel
df = pd.read_csv('cleaned_with_male_head.csv')

# Collapse the survey's regions into the broader groupings used in the paper.
# Regions not listed here keep their original label.
region_merge = {
    'Northern Midlands and Mountains': 'Northern Midlands and Mountainous Area',
    'Northwest': 'Northern Midlands and Mountainous Area',
    'North Central Coast': 'North Central & Central Coastal Area',
    'South Central Coast': 'North Central & Central Coastal Area',
}
df['region_grouped'] = df['region'].replace(region_merge)

# Regions reported in the paper, in table order
regions = [
    'Red River Delta',
    'Northern Midlands and Mountainous Area',
    'North Central & Central Coastal Area',
    'Central Highlands',
    'Mekong River Delta',
]

# Store the number-of-children and survey-year columns as pandas categoricals
for cat_col in ('CHILD', 'year'):
    df[cat_col] = df[cat_col].astype('category')

# Within-group (household) demeaning helper used for the fixed-effects transform
def demean(df, vars_to_demean, group_col='code_id_hh'):
    """Return a copy of *df* with selected columns centred on their group means.

    Each listed column is replaced by its value minus the mean of that column
    within the row's ``group_col`` group. Columns named in ``vars_to_demean``
    that are absent from *df* are skipped silently. The input frame is not
    modified.

    Args:
        df: Source DataFrame; must contain ``group_col``.
        vars_to_demean: Column names to centre, processed in the given order.
        group_col: Column whose values define the groups (default household id).

    Returns:
        A new DataFrame with the demeaned columns; all other columns unchanged.
    """
    centred = df.copy()
    present = [name for name in vars_to_demean if name in centred.columns]
    for name in present:
        within_mean = centred.groupby(group_col)[name].transform('mean')
        centred[name] = centred[name].sub(within_mean)
    return centred

# Dictionary to store results, keyed by region name
results = {}

# Regression specification. Every regressor (numeric and dummy-expanded
# categorical) is within-household demeaned, which is what makes plain OLS on
# the transformed data equal to the household fixed-effects estimator.
outcome_var = 'fert'
numeric_regressors = ['htype', 'mar', 'emp', 'pri', 'sec', 'high', 'inc', 'fam']
categorical_controls = ['CHILD', 'year']
household_col = 'code_id_hh'

# Run regression for each region
for reg in regions:
    df_reg = df[df['region_grouped'] == reg].copy()
    if df_reg.empty:
        print(f"No data for region: {reg}")
        continue

    # Restrict to the estimation sample *before* demeaning: rows with a missing
    # model variable would otherwise shift the household means and then be
    # dropped by the fit, so the transform and the regression would disagree.
    model_cols = [outcome_var, *numeric_regressors, *categorical_controls, household_col]
    df_reg = df_reg.dropna(subset=model_cols)
    if df_reg.empty:
        print(f"No complete observations for region: {reg}")
        continue

    # Expand CHILD and year into dummies (first level dropped, levels unused in
    # this region removed) so they can be demeaned like every other regressor.
    # Leaving raw dummies next to demeaned regressors does not reproduce the
    # fixed-effects estimator (Frisch-Waugh-Lovell).
    controls = pd.DataFrame({
        col: df_reg[col].astype('category').cat.remove_unused_categories()
        for col in categorical_controls
    })
    dummies = pd.get_dummies(controls, drop_first=True, dtype=float)
    # Generic names keep the formula valid whatever the category labels look like.
    dummy_cols = [f'ctrl_{k}' for k in range(dummies.shape[1])]
    dummies.columns = dummy_cols

    df_model = pd.concat(
        [df_reg[[outcome_var, *numeric_regressors, household_col]], dummies],
        axis=1,
    )
    df_reg_demean = demean(df_model, [outcome_var, *numeric_regressors, *dummy_cols])

    # Model: OLS on household-demeaned data (household FE) with demeaned
    # CHILD/year dummies and standard errors clustered by household.
    formula = f"{outcome_var} ~ " + ' + '.join([*numeric_regressors, *dummy_cols])
    # Integer cluster codes avoid relying on statsmodels handling string ids.
    cluster_ids = pd.factorize(df_reg_demean[household_col])[0]
    model = smf.ols(formula, data=df_reg_demean).fit(
        cov_type='cluster', cov_kwds={'groups': cluster_ids}
    )

    # Extract key stats for 'htype'
    coeff = model.params.get('htype', np.nan)
    se = model.bse.get('htype', np.nan)

    # Normal-approximation 95% confidence interval
    ci_low = coeff - 1.96 * se
    ci_high = coeff + 1.96 * se

    results[reg] = {
        'coeff': round(coeff, 3),
        'ci_low': round(ci_low, 3),
        'ci_high': round(ci_high, 3),
        'observations': int(model.nobs),
        # Count households in the estimation sample, not the raw region subset.
        'households': df_reg_demean[household_col].nunique(),
        'rsquared': round(model.rsquared, 3),
    }

# Create the table as a DataFrame. Regions skipped above (no usable data) have
# no entry in `results`, so their cells are rendered as nan.
table6_path = 'revise_table_6_region.csv'

table6_rows = []
for reg in regions:
    stats = results.get(reg, {})
    table6_rows.append({
        'Region': reg,
        'Living with grandparent(s) ^a': (
            f"{stats.get('coeff', np.nan):.3f} "
            f"({stats.get('ci_low', np.nan):.3f} to {stats.get('ci_high', np.nan):.3f})"
        ),
        'Observations': stats.get('observations', np.nan),
        'Number of households': stats.get('households', np.nan),
        'R-squared': stats.get('rsquared', np.nan),
    })
table6 = pd.DataFrame(table6_rows)

print("\nRevised Table 6: Impact of living with grandparent(s) by region\n")
print(table6)
table6.to_csv(table6_path, index=False)
# Derive the message from the path actually written so the two cannot diverge.
print(f"Saved {table6_path}")
print("Notes: 95% confidence intervals in parentheses. ^a Reference = Not living with grandparent(s).")